import random
import os
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint
import hydra
from hydra.utils import instantiate
from collections import defaultdict
import json
import logging
from src.pl_model.N_classification_general import N_ClassificationModel
import torch.nn.functional as F
from src.pl_model.classification_model import ClassificationModel

from src.pl_model.distillation import Distilltion
from src.pl_model.my_distillation import MyDistillation
from src.pl_model.my_distillation_qkt_multi_teachers import MyDistillationMutlipleTeachers_QKT
from src.pl_model.my_distillation_qkt_multi_teachers_fc import MyDistillationMutlipleTeachers_QKT_fc

from pytorch_lightning import Trainer
from src.utils.config_utils import get_configs
from torchvision import datasets, transforms
from src.utils.load_models import load_models, load_model3, load_model_qktD
import wandb
import copy

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Path of the plain-text run summary. `my_app` truncates and rewrites it at start-up and
# the other functions in this file append to it. The path is relative, so it resolves
# against the process working directory at the moment the file is opened.
log_file_path = "training_summary.log"


@hydra.main(config_path="conf", config_name="transfer_config_rl")
def my_app(cfg: DictConfig) -> None:
    """Run knowledge transfer for every learner client and log the averaged metrics.

    Steps, in order:

    1. Merge the training-run configuration (``get_configs(cfg.train_exp_id)``) with
       ``cfg`` and call ``set_seed`` with the effective seed.
    2. Optionally measure the pre-transfer per-class accuracy of every learner client.
    3. Build one query per learner client (personalized, predefined, single-class or
       variable, depending on the configuration).
    4. Optionally run a single centralized stage-1 transfer shared by all clients.
    5. For each learner client: choose its teacher candidates, run the two-stage or the
       single-stage transfer, collect the returned metrics and append a summary to
       ``log_file_path``.
    6. Average the collected metrics per scenario and write them to the log file and to
       the experiment logger's summary.

    Args:
        cfg: Hydra/OmegaConf configuration. It is mutated throughout the run
            (``learner_client``, ``teacher_client``, ``goal_class``, ``queries``, ...),
            and the helpers called from here read those mutated fields.
    """
    # Function-scope import so the command line is read from the documented ``sys.argv``
    # instead of the incidental ``os.sys`` attribute.
    import sys

    learner_client = cfg.learner_client
    teacher_client = cfg.teacher_client
    cfg.detailed_testing = True
    log.info(f"Configs:\n{OmegaConf.to_yaml(cfg)}")
    train_cfg = get_configs(cfg.train_exp_id)
    # Values supplied for this run take precedence over those of the training run.
    cfg = OmegaConf.merge(train_cfg, cfg)
    cfg.stage1_epochs = cfg.trainer.max_epochs
    if not cfg.stage2_epochs:
        cfg.stage2_epochs = cfg.trainer.max_epochs

    seed = cfg.get('seed', 42)
    set_seed(seed)
    print(f"seed: {seed}")
    total_comm = 0

    data_dists_vectorized = get_data_dists_vectorized(cfg, num_classes=cfg.num_classes)
    print(f"data_dists_vectorized: {data_dists_vectorized}\n")

    log.info(f"Instantiating logger <{cfg.logger._target_}>")
    logger = instantiate(cfg.logger)

    # Pre-transfer accuracy is measured here only on request; otherwise the map stays
    # None and is passed through unchanged to the transfer helpers.
    if cfg.measure_pre_transfer_acc:
        pre_transfer_acc_map = {}
        teacher_models = load_models(clients_ids=teacher_client, clients=cfg.clients,
                                     model=instantiate(cfg.model))
        for client_idx in learner_client:
            cfg.learner_client = client_idx
            learner_model = teacher_models[cfg.learner_client]
            log.info(f"Calculating pre-transfer accuracy for client {client_idx}")
            pre_transfer_acc_map[client_idx] = calculate_pre_transfer_accuracy(cfg, logger, learner_model)
    else:
        pre_transfer_acc_map = None

    if cfg.personalized_qkt:
        print("queries based on each client distribution (personalized_qkt)!")
        queries = get_clients_Qs_personalized(cfg, data_dists_vectorized=data_dists_vectorized,
                                              num_classes=cfg.num_classes)
    elif cfg.predefined_queries:
        print("predefined_queries!")
        queries = cfg.predefined_queries
    elif cfg.num_classes_to_select == 1:
        print("Single-class queries!")
        queries = get_clients_Qs(cfg, data_dists_vectorized=data_dists_vectorized, num_classes=cfg.num_classes)
    else:
        print("multi-class and variable queries!")
        queries = get_clients_Qs_variable(cfg, data_dists_vectorized=data_dists_vectorized,
                                          num_classes=cfg.num_classes)

    print(f"queries: {queries}")
    cfg.queries = queries

    queries_list = OmegaConf.to_container(cfg.queries, resolve=True)
    logger.experiment.summary["queries"] = queries_list
    logger.experiment.summary["num_classes_to_select"] = cfg.num_classes_to_select
    logger.experiment.summary["qkt_multi_teachers?"] = cfg.qkt_multi_teachers
    # Record the seed that was actually applied above, which may be the default
    # when the configuration does not define one.
    logger.experiment.summary["seed"] = seed

    # One bucket per checkpoint scenario. The key order must match the order of the
    # result tuples returned by the transfer functions, because they are zipped below.
    total_metrics = {
        'best_val_acc': defaultdict(list),
        'best_simple_weighted_accuracy': defaultdict(list),
        'best_uniform_accuracy': defaultdict(list),
        'best_query_class_acc_gain': defaultdict(list),
        'least_forgetting': defaultdict(list),
        'latest': defaultdict(list),
    }

    run_command = " ".join(["python"] + sys.argv)
    with open(log_file_path, 'w') as f:
        f.write(f"Run command: {run_command}\n\n")

    complete_log_file_path = os.path.abspath(log_file_path)
    print(f"Summary and averages logged to: {complete_log_file_path}")

    logger.experiment.summary["complete_log_file_path"] = complete_log_file_path

    if cfg.centralized_qkt:
        print(">> creating the combined model for stage1 (one time and shared by all clients)")
        cfg.stage2 = False
        cfg.with_EWC_fc = False

        if cfg.centralized_qkt_use_client_data:
            # Volunteer client: the one holding the most samples (first one wins on ties).
            client_with_most_samples = None
            max_samples = -1
            for client_id, distribution in data_dists_vectorized.items():
                total_samples = sum(distribution)
                if total_samples > max_samples:
                    max_samples = total_samples
                    client_with_most_samples = client_id

            if client_with_most_samples is not None:
                print(f"Using the volunteer_client: {client_with_most_samples} with the most samples: {max_samples}")
                # Client names look like "client_<n>"; keep only the number.
                cfg.learner_client = int(client_with_most_samples.split('_')[1])
            else:
                cfg.learner_client = 0  # Default to any client if something went wrong
        else:
            cfg.learner_client = 0  # Any client

        # Only the stage-1 model is needed here; the stage-1 metrics are not used.
        centralized_stage1_model, _ = perform_centralized_knowledge_transfer(
            cfg, logger, data_dists_vectorized, pre_transfer_acc_map=pre_transfer_acc_map)
        cfg.with_CE = True

    for idx, client_idx in enumerate(learner_client):
        cfg.learner_client = client_idx
        log.info(f"Learner_client: {cfg.learner_client}")
        # Queries are positional: the i-th query belongs to the i-th learner client.
        learner_query = queries[idx]
        cfg.goal_class = learner_query
        log.info(f"The query for client {cfg.learner_client} is to learn class: {cfg.goal_class}")

        if cfg.filter_with_noise:
            teacher_models = load_models(clients_ids=teacher_client, clients=cfg.clients,
                                         model=instantiate(cfg.model))
            print("Filtering teachers using NOISE ...")  # our approach
            teacher_candidates = my_find_teacher_candidates(cfg, learner_query, teacher_models)
            cfg.teacher_candidates = teacher_candidates

        elif cfg.use_all_teachers:
            print("Using all teachers ...")
            cfg.teacher_candidates = teacher_client  # all teachers

        else:
            print("Filtering teachers using the ground truth number of samples ...")
            teacher_candidates = find_teacher_candidates(cfg, learner_query, teacher_client,
                                                         data_dists_vectorized=data_dists_vectorized,
                                                         sample_threshold=cfg.teacher_sample_threshold)
            cfg.teacher_candidates = teacher_candidates

        cfg.teacher_client = cfg.teacher_candidates
        total_comm += len(cfg.teacher_candidates)
        log.info(f"using the teachers: {cfg.teacher_client}")
        description = ""

        if cfg.two_stage_qkt:
            description = "two_stage_qkt"
            if cfg.centralized_qkt:
                results = perform_two_stage_knowledge_transfer(cfg, logger, data_dists_vectorized,
                                                               pre_transfer_acc_map,
                                                               centralized_stage1_model=centralized_stage1_model)
            else:
                results = perform_two_stage_knowledge_transfer(cfg, logger, data_dists_vectorized,
                                                               pre_transfer_acc_map)
        else:
            results = transfer_between_clients(cfg, logger, data_dists_vectorized, pre_transfer_acc_map)

        if results is not None:
            for scenario, metrics in zip(total_metrics.keys(), results):
                if metrics is not None:
                    for key, value in metrics.items():
                        if value is not None:
                            total_metrics[scenario][key].append(value)

            summary_text = print_summary(cfg, data_dists_vectorized, results, description)
            log_summary(summary_text, log_file_path)
            print(f"Summary logged to: {complete_log_file_path}")
            log.info(f"Done learning for client {cfg.learner_client}!")
            log.info("----------------------------------------------")

    avg_metrics = {}
    std_metrics = {}
    print(f"total_metrics: {total_metrics}")
    for scenario, metrics in total_metrics.items():
        if metrics:  # Only calculate averages if metrics is not empty
            avg_metrics[scenario] = {metric: np.mean([val for val in values if val is not None])
                                     for metric, values in metrics.items()}
            std_metrics[scenario] = {metric: np.std([val for val in values if val is not None])
                                     for metric, values in metrics.items()}
        else:
            avg_metrics[scenario] = {}
            std_metrics[scenario] = {}

    def print_avg_metrics(scenario, metrics, std_devs):
        """Format the averages (and standard deviations) of one scenario as text."""
        avg_text = f"---{scenario}---\n"
        for key in metrics.keys():
            avg_text += f"avg_{key}: {metrics[key]} (std_dev: {std_devs[key]})\n"
        avg_text += "-----\n"
        return avg_text

    print()
    print()

    # Human-readable heading for each scenario key, in reporting order.
    scenario_titles = (
        ("Best Val_Acc model", 'best_val_acc'),
        ("Best Simple Weighted Accuracy model", 'best_simple_weighted_accuracy'),
        ("Best Uniform Accuracy model", 'best_uniform_accuracy'),
        ("Best Query_classes_acc_gain model", 'best_query_class_acc_gain'),
        ("Least forgetting model", 'least_forgetting'),
        ("Latest model", 'latest'),
    )
    averages_text = "Averages:\n-----\n"
    for title, scenario in scenario_titles:
        averages_text += print_avg_metrics(title, avg_metrics[scenario], std_metrics[scenario])

    print(averages_text)

    print("--- End of averages ---")

    logger.experiment.summary["averages"] = avg_metrics
    logger.experiment.summary["total_communication"] = total_comm

    print(f"avg_metrics: {avg_metrics}")

    print("---------------")

    log_summary(averages_text, log_file_path)

    print(f"total_communication: {total_comm}")
    complete_log_file_path = os.path.abspath(log_file_path)
    print(f"Summary and averages logged to: {complete_log_file_path}")
    print()


def log_summary(text, log_file_path):
    """Append ``text`` plus a trailing newline to the file at ``log_file_path``.

    The file is opened in append mode, so it is created when it does not exist yet
    and existing content is kept.
    """
    with open(log_file_path, mode='a') as summary_file:
        summary_file.write("".join((text, "\n")))


def calculate_pre_transfer_accuracy(cfg, logger, learner_model):
    """Test ``learner_model`` on the data of ``cfg.learner_client`` before any transfer.

    The model is wrapped in an ``N_ClassificationModel`` and run through a single
    ``Trainer.test`` pass with checkpointing disabled.

    Returns:
        The ``per_class_test_acc`` attribute of the wrapped model after the test pass.
    """
    client_data = initialize_data_module(cfg, cfg.learner_client)
    wrapped_learner = N_ClassificationModel(cfg, learner_model=learner_model)
    evaluator = Trainer(logger=logger, checkpoint_callback=False)
    evaluator.test(wrapped_learner, datamodule=client_data)
    return wrapped_learner.per_class_test_acc


def calculate_client_all_accuracies(cfg, per_class_test_acc, data_dists_vectorized, pre_transfer_acc_map):
    """Derive the aggregate accuracy metrics of the current learner client.

    Only classes that the client holds locally (positive sample count) or that belong
    to its query (``cfg.goal_class``) take part in the uniform and weighted accuracies.

    Args:
        cfg: Configuration; reads ``learner_client``, ``goal_class``,
            ``measure_pre_transfer_acc`` and, when the latter is falsy, ``train_exp_id``.
        per_class_test_acc: Per-class test accuracy after transfer, indexable by class.
        data_dists_vectorized: Mapping ``"client_<id>"`` -> per-class sample counts.
        pre_transfer_acc_map: Mapping learner id -> per-class accuracy before transfer.
            Used only when ``cfg.measure_pre_transfer_acc`` is truthy; otherwise the
            baseline is read from the summary of the W&B run ``cfg.train_exp_id``.

    Returns:
        Tuple ``(uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy,
        query_classes_acc_gain, local_classes_accuracy, forgetting)``.

        * ``uniform_accuracy``: plain mean over local and query classes.
        * ``simple_weighted_accuracy``: weighted mean; query classes weigh 1, local
          classes weigh their share of the client's samples.
        * ``query_classes_accuracy`` / ``query_classes_acc_gain``: mean accuracy, and
          mean accuracy change against the baseline, over the query classes.
        * ``local_classes_accuracy``: sample-share weighted mean over local, non-query
          classes.
        * ``forgetting``: sum of the negative accuracy changes divided by the number of
          classes whose baseline accuracy was positive (so it is <= 0); 0 when no class
          had a positive baseline.
    """
    learner_client = cfg.learner_client
    class_distribution = data_dists_vectorized[f'client_{learner_client}']
    query_classes = cfg.goal_class

    num_classes = len(per_class_test_acc)
    query_class_acc = [per_class_test_acc[i] for i in query_classes]

    if cfg.measure_pre_transfer_acc:
        pre_transfer_acc = pre_transfer_acc_map[learner_client]
    else:
        # No locally measured baseline: read it from the training run's W&B summary.
        train_run = wandb.Api().run(cfg.train_exp_id)
        train_run_summary = train_run.summary._json_dict
        pre_transfer_acc = train_run_summary[f'client-{learner_client}/per_class_test_acc']

    query_class_acc_gain = [per_class_test_acc[i] - pre_transfer_acc[i] for i in query_classes]

    # Loop-invariant: total number of samples held by the client.
    total_samples = sum(class_distribution)

    total_uniform_acc = total_weighted_acc = total_local_class_acc = 0
    count_uniform_classes = total_weight = total_local_weight = 0

    for cls_index in range(num_classes):
        held_locally = class_distribution[cls_index] > 0
        is_query_class = cls_index in query_classes
        if not (held_locally or is_query_class):
            continue

        cls_acc = per_class_test_acc[cls_index]
        total_uniform_acc += cls_acc
        count_uniform_classes += 1

        if is_query_class:
            weight = 1
        else:
            # Reaching this branch means the class is held locally, so total_samples > 0.
            weight = class_distribution[cls_index] / total_samples
            total_local_class_acc += weight * cls_acc
            total_local_weight += weight

        total_weighted_acc += weight * cls_acc
        total_weight += weight

    uniform_accuracy = total_uniform_acc / count_uniform_classes if count_uniform_classes > 0 else 0
    simple_weighted_accuracy = total_weighted_acc / total_weight if total_weight > 0 else 0
    query_classes_accuracy = sum(query_class_acc) / len(query_classes) if query_classes else 0
    query_classes_acc_gain = sum(query_class_acc_gain) / len(query_class_acc_gain) if query_classes else 0
    local_classes_accuracy = total_local_class_acc / total_local_weight if total_local_weight > 0 else 0

    accuracy_deltas = (per_class_test_acc[j] - pre_transfer_acc[j] for j in range(num_classes))
    total_accuracy_drop = sum(delta for delta in accuracy_deltas if delta < 0)
    previously_known_classes = sum(1 for accuracy in pre_transfer_acc if accuracy > 0)
    # Guard against a baseline without any positive accuracy (would divide by zero).
    forgetting = total_accuracy_drop / previously_known_classes if previously_known_classes > 0 else 0

    return (uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy, query_classes_acc_gain,
            local_classes_accuracy, forgetting)


def transfer_between_clients(cfg: DictConfig, logger, data_dists_vectorized, pre_transfer_acc_map):
    """Run single-stage knowledge transfer for ``cfg.learner_client`` and test its checkpoints.

    The function merges the training configuration into ``cfg``, builds the callbacks,
    datamodule, distillation module and trainer, fits the learner, records run
    information in the experiment logger's summary and finally evaluates each manually
    saved checkpoint found under ``trainer.default_root_dir`` through ``test_model``.

    Args:
        cfg: Run configuration; ``learner_client``, ``teacher_client``, ``goal_class``
            and ``teacher_candidates`` must already be set by the caller.
        logger: Experiment logger passed to the trainer; its ``experiment.summary`` is
            written to.
        data_dists_vectorized: Mapping of client name to per-class sample counts.
        pre_transfer_acc_map: Pre-transfer per-class accuracies, forwarded to
            ``test_model``.

    Returns:
        A 6-tuple with the ``test_model`` result for each checkpoint scenario, in the
        order: best validation accuracy, best simple weighted accuracy, best uniform
        accuracy, best query-class accuracy gain, least forgetting, latest.

    Raises:
        ValueError: If none of ``cfg.data_kd``, ``cfg.qkt_multi_teachers`` or
            ``cfg.my_kd`` is enabled, i.e. no distillation method is selected.
    """
    log.info("Loading the training configs")
    train_cfg = get_configs(cfg.train_exp_id)

    cfg.datamodule.datamodule = train_cfg.old_datamodule
    cfg.datamodule.datamodule.split_function = None

    learner_entry = train_cfg["clients"][f"client_{cfg.learner_client}"]
    cfg.datamodule.learner_train_indices = learner_entry["train_data_indices"]
    cfg.datamodule.learner_val_indices = learner_entry["val_data_indices"]

    cfg = OmegaConf.merge(train_cfg, cfg)
    cfg.with_earlystopping = False  # for these experiments
    print(f"cfg.with_earlystopping? {cfg.with_earlystopping}")

    callbacks = []
    metric_to_monitor = f"client-{cfg.learner_client}_from_client-{cfg.teacher_client}/val_acc"
    if "callbacks" in cfg:
        for _, cb_conf in cfg.callbacks.items():
            if "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                if cb_conf._target_ == "pytorch_lightning.callbacks.EarlyStopping":
                    if cfg.with_earlystopping:
                        cb_conf.monitor = metric_to_monitor
                        patience = cb_conf.patience
                        log.info(f"Early-stopping, with patience = {patience}")
                    else:
                        # Early stopping disabled: do not instantiate this callback.
                        log.info(">> No early-stopping")
                        continue
                elif cb_conf._target_ == "pytorch_lightning.callbacks.ModelCheckpoint":
                    cb_conf.monitor = metric_to_monitor
                    if cfg.learner_client == -1:
                        # Learner id -1: store next to the teacher's model instead.
                        cb_conf.dirpath = f'{os.path.dirname(cfg.clients[f"client_{cfg.teacher_client}"].model_path)}/transfered_models'
                        cb_conf.filename = f"untrained_client_taught_by_client_{cfg.teacher_client}"
                    else:
                        cb_conf.dirpath = f'{os.path.dirname(cfg.clients[f"client_{cfg.learner_client}"].model_path)}/transfered_models'
                        cb_conf.filename = f"client_{cfg.learner_client}_taught_by_client_{cfg.teacher_client}"
                callbacks.append(instantiate(cb_conf))

    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    datamodule = instantiate(cfg.datamodule)

    if cfg.data_kd:
        pl_model = Distilltion(cfg=cfg)
    elif cfg.qkt_multi_teachers:
        if cfg.with_EWC_fc:
            pl_model = MyDistillationMutlipleTeachers_QKT_fc(cfg)
        else:
            pl_model = MyDistillationMutlipleTeachers_QKT(cfg)
    elif cfg.my_kd:
        pl_model = MyDistillation(cfg=cfg)
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(
            "No distillation method selected: enable one of 'data_kd', "
            "'qkt_multi_teachers' or 'my_kd' in the configuration.")

    log.info(f"Instantiating Method <{type(pl_model)}>")

    trainer: Trainer = instantiate(cfg.trainer, logger=logger, callbacks=callbacks)
    if cfg.validate_and_test:
        log.info("Performing a sanity validation epoch to record performance before learning")
        trainer.validate(model=pl_model, datamodule=datamodule)
        log.info("Performing a sanity test epoch to record performance before learning")
        trainer.test(model=pl_model, datamodule=datamodule)

    # Reset the best-accuracy tracker so the optional sanity pass above does not count.
    pl_model.val_acc_best.reset()

    log.info("Starting the learning process")
    trainer.fit(model=pl_model, datamodule=datamodule)

    if cfg.qkt_save_models:
        # All clients' trained models share one folder.
        save_folder = os.path.join("saved_models", "all_clients_models")
        os.makedirs(save_folder, exist_ok=True)

        model_path = save_model(pl_model.learner_model, save_folder,
                                f"client_{cfg.learner_client}_trained_model.pth")

        with open(log_file_path, 'a') as f:
            f.write(f"Trained Model Path for client {cfg.learner_client}: {model_path}\n")
        print(f"Trained Model Path for client {cfg.learner_client}: {model_path}")

    summary = logger.experiment.summary
    summary_prefix = f"learner_client-{cfg.learner_client}_from_client-{cfg.teacher_client}"
    summary[f"{summary_prefix}/best_val_acc"] = trainer.checkpoint_callback.best_model_score
    summary[f"{summary_prefix}/best_model_path"] = trainer.checkpoint_callback.best_model_path
    # NOTE(review): these two are read after fit() finished, so they presumably hold the
    # final epoch/step rather than those of the best checkpoint - confirm.
    summary[f"{summary_prefix}/best_epoch"] = trainer.model.current_epoch
    summary[f"{summary_prefix}/best_step"] = trainer.model.global_step
    summary["trainer.default_root_dir"] = trainer.default_root_dir
    summary["T"] = pl_model.T
    summary["data_dists_vectorized"] = data_dists_vectorized
    summary[f"learner_client-{cfg.learner_client}_query"] = list(cfg.goal_class)
    summary[f"learner_client-{cfg.learner_client}_teacher_candidates"] = list(cfg.teacher_candidates)
    print(f"trainer.current_epoch: {trainer.current_epoch}")
    print(f"trainer.model.current_epoch: {trainer.model.current_epoch}")

    if not cfg.qkt_unweighted_teachers:
        summary[f"{summary_prefix}/alpha"] = pl_model.alpha
    else:
        summary[f"{summary_prefix}/qkt_unweighted_teachers"] = cfg.qkt_unweighted_teachers

    log.info("Performing a validation epoch to record performance after learning")
    trainer.validate(model=pl_model, datamodule=datamodule)

    log.info("Testing the learner client model after learning")

    print(f"trainer.default_root_dir: {trainer.default_root_dir}")

    print("testing the manual saved checkpoints...")
    # (checkpoint file name, pl_model attribute holding its epoch, description);
    # the order defines the order of the returned tuple.
    checkpoint_specs = (
        ("best_val_acc.ckpt", "best_val_acc_epoch", "Best_val_acc"),
        ("val_best_simple_weighted_accuracy.ckpt", "best_simple_weighted_accuracy_epoch",
         "Best_simple_weighted_accuracy"),
        ("val_best_uniform_accuracy.ckpt", "best_uniform_accuracy_epoch", "Best_uniform_accuracy"),
        ("val_best_query_class_acc_gain.ckpt", "best_query_class_acc_gain_epoch",
         "Best_query_class_acc_gain"),
        ("val_least_forgetting.ckpt", "least_forgetting_epoch", "Least_forgetting"),
    )

    results = []
    for ckpt_name, epoch_attr, description in checkpoint_specs:
        ckpt_path = os.path.join(trainer.default_root_dir, ckpt_name)
        results.append(test_model(cfg, datamodule, ckpt_path, logger, data_dists_vectorized,
                                  getattr(pl_model, epoch_attr),
                                  pre_transfer_acc_map=pre_transfer_acc_map,
                                  description=description))

    # The latest checkpoint is reported with the trainer's current epoch.
    latest_ckpt = os.path.join(trainer.default_root_dir, "latest.ckpt")
    results.append(test_model(cfg, datamodule, latest_ckpt, logger, data_dists_vectorized,
                              trainer.model.current_epoch, pre_transfer_acc_map=pre_transfer_acc_map,
                              description="Latest"))

    return tuple(results)


def perform_two_stage_knowledge_transfer(cfg, logger, data_dists_vectorized, pre_transfer_acc_map,
                                         centralized_stage1_model=None):
    """Run the two-stage knowledge transfer for ``cfg.learner_client``.

    Stage 1 produces a new student model: either by running
    ``perform_knowledge_transfer`` locally, or, when ``cfg.centralized_qkt`` is set, by
    copying ``centralized_stage1_model``. Unless ``cfg.no_head_replacement`` is set, the
    stage-1 model is then passed through ``replace_classifier`` together with the
    initial student model. Stage 2 runs ``perform_knowledge_transfer`` again, starting
    from that model, with ``with_EWC_fc``, ``freeze_backbone`` and ``stage2`` enabled.

    When ``cfg.debug`` is set, only a single transfer with ``with_EWC_fc`` enabled is run.

    Args:
        cfg: Run configuration; several flags are mutated here.
        logger: Experiment logger forwarded to ``perform_knowledge_transfer``.
        data_dists_vectorized: Mapping of client name to per-class sample counts.
        pre_transfer_acc_map: Pre-transfer accuracies forwarded to the transfer calls.
        centralized_stage1_model: Shared stage-1 model; used only when
            ``cfg.centralized_qkt`` is set.

    Returns:
        The results element (third item) returned by the final
        ``perform_knowledge_transfer`` call.
    """
    print("Performing two_stage_knowledge_transfer...")

    print(f"Learner_client: {cfg.learner_client}")
    print(f"The query for client {cfg.learner_client} is to learn class: {cfg.goal_class}")
    print(f"Teacher candidates for client {cfg.learner_client} are: {cfg.teacher_client}")
    starting_point = cfg.two_stage_starting_point

    # Debug mode: a single transfer with the EWC-fc variant, nothing else.
    if cfg.debug:
        print(">> DEBUG << Step 2")
        cfg = reset_params(cfg)
        cfg.with_EWC_fc = True
        _, _, debug_results = perform_knowledge_transfer(cfg, logger, data_dists_vectorized,
                                                         pre_transfer_acc_map=pre_transfer_acc_map)
        return debug_results

    if cfg.centralized_qkt:
        # Centralized: reuse the shared stage-1 model handed in by the caller.
        print("skipping stage1 and using the centralized_stage1_model")
        stage1_student = copy.deepcopy(centralized_stage1_model)

        # NOTE(review): elsewhere in this file the result of load_models is indexed by
        # client id; here it is used as-is - confirm this is the intended object.
        initial_student_model = load_models(clients_ids=cfg.learner_client, clients=cfg.clients,
                                            model=instantiate(cfg.model))
    else:
        save_folder = os.path.join("saved_models", f"client_{cfg.learner_client}")

        print("Step 1:")
        cfg.stage2 = False
        cfg.with_EWC_fc = False
        if starting_point == 'kd':
            cfg.qkt_unweighted_teachers = True
        print(
            f"starting point: ({cfg.two_stage_starting_point}, qkt_unweighted_teachers?{cfg.qkt_unweighted_teachers})")
        initial_student_model, stage1_student, step1_results = perform_knowledge_transfer(
            cfg, logger, data_dists_vectorized, pre_transfer_acc_map=pre_transfer_acc_map)
        cfg = reset_params(cfg)

        if cfg.qkt_save_models:
            print(f"Saving models for client {cfg.learner_client} in folder: {save_folder}")
            initial_model_path = save_model(initial_student_model, save_folder,
                                            f"client_{cfg.learner_client}_initial_student_model.pth")
            model_after_step1_path = save_model(stage1_student, save_folder,
                                                f"client_{cfg.learner_client}_model_after_step1.pth")

            # Record both paths in the run summary file and on stdout.
            with open(log_file_path, 'a') as summary_file:
                summary_file.write(f"Initial Student Model Path: {initial_model_path}\n")
                summary_file.write(f"Model After Step 1 Path: {model_after_step1_path}\n")
            print(f"Initial Student Model Path: {initial_model_path}")
            print(f"Model After Step 1 Path: {model_after_step1_path}")

    if cfg.no_head_replacement:
        modified_model = copy.deepcopy(stage1_student)
        print("using the new model with NO head replacement.")
    else:
        print(">> Replacing classifier head")
        modified_model = replace_classifier(copy.deepcopy(stage1_student), initial_student_model)
        print("using the new model and the OLD student head")

    print("Step 2: Perform knowledge transfer using the new model and the old student head")
    cfg = reset_params(cfg)
    cfg.with_EWC_fc = True
    cfg.freeze_backbone = True
    cfg.stage2 = True
    initial_student_model, stage2_student, two_stages_qkt_results = perform_knowledge_transfer(
        cfg, logger, data_dists_vectorized, starting_model=modified_model,
        pre_transfer_acc_map=pre_transfer_acc_map)

    # The final model is saved only for the local (non-centralized) flow, where
    # save_folder was defined above.
    if cfg.qkt_save_models and not cfg.centralized_qkt:
        model_after_stage2_path = save_model(stage2_student, save_folder,
                                             f"client_{cfg.learner_client}_model_after_stage2.pth")
        with open(log_file_path, 'a') as summary_file:
            summary_file.write(f"Model After Stage 2 Path: {model_after_stage2_path}\n")
        print(f"Model After Stage 2 Path: {model_after_stage2_path}")

    if not cfg.centralized_qkt:
        # Stage-1 results exist only for the local flow; they are logged after stage 2.
        summary_text = print_summary(cfg, data_dists_vectorized, step1_results,
                                     f"step1 ({starting_point}) results", step1=True)

        log_summary(summary_text, log_file_path)
        complete_log_file_path = os.path.abspath(log_file_path)
        print(f"step1 Summary logged to: {complete_log_file_path}")

    return two_stages_qkt_results


def reset_params(cfg):
    """Switch off the per-stage transfer flags on ``cfg`` and return the same object.

    The flags set to ``False`` are ``with_EWC_fc``, ``data_kd``,
    ``qkt_unweighted_teachers`` and ``freeze_backbone``; ``cfg`` is modified in place.
    """
    for flag_name in ("with_EWC_fc", "data_kd", "qkt_unweighted_teachers", "freeze_backbone"):
        setattr(cfg, flag_name, False)
    return cfg


def perform_knowledge_transfer(cfg, logger, data_dists_vectorized, starting_model=None, pre_transfer_acc_map=None):
    """Run one knowledge-transfer stage for the learner client and evaluate the outcome.

    The function trains the selected distillation module, validates and tests the
    final weights, and then re-tests each of the manually saved checkpoints found
    under the trainer's ``default_root_dir``.

    Args:
        cfg: Run configuration. It is mutated: ``num_epochs`` is set from
            ``stage2_epochs`` or ``stage1_epochs``, and in centralized mode
            ``with_CE`` / ``qkt_unweighted_teachers`` are overwritten.
        logger: Logger handed to the trainer(s).
        data_dists_vectorized: Per-client class-count vectors, forwarded to the
            accuracy computation.
        starting_model: Model passed as ``l_model`` to the ``_fc`` distillation
            variant; ignored by the other variants.
        pre_transfer_acc_map: Forwarded unchanged to the accuracy computation.

    Returns:
        ``(pl_model.org_learner_model_copy, pl_model.learner_model, results)`` where
        ``results`` is a 6-tuple of metric dicts in the order: best val acc, best
        simple weighted accuracy, best uniform accuracy, best query-class gain,
        least forgetting, latest weights.

    Raises:
        ValueError: If none of ``cfg.data_kd``, ``cfg.qkt_multi_teachers`` or
            ``cfg.my_kd`` is enabled, so no distillation module can be built.
    """
    datamodule = initialize_data_module(cfg, cfg.learner_client)

    # Each stage has its own epoch budget.
    cfg.num_epochs = cfg.stage2_epochs if cfg.stage2 else cfg.stage1_epochs
    print(f"num_epochs: {cfg.num_epochs}")

    if cfg.centralized_qkt:
        # Centralized mode: stage 1 runs without CE and with unweighted teachers,
        # stage 2 flips both switches.
        in_stage2 = bool(cfg.stage2)
        cfg.with_CE = in_stage2
        cfg.qkt_unweighted_teachers = not in_stage2

    if cfg.data_kd:
        pl_model = Distilltion(cfg=cfg)
    elif cfg.qkt_multi_teachers:
        if cfg.with_EWC_fc:
            pl_model = MyDistillationMutlipleTeachers_QKT_fc(cfg, l_model=starting_model)
        else:
            pl_model = MyDistillationMutlipleTeachers_QKT(cfg)
    elif cfg.my_kd:
        pl_model = MyDistillation(cfg=cfg)
    else:
        # Previously this fell through and crashed later with an UnboundLocalError.
        raise ValueError(
            "No knowledge-transfer method selected: enable one of "
            "cfg.data_kd, cfg.qkt_multi_teachers or cfg.my_kd."
        )

    log.info(f"Instantiating Method <{type(pl_model)}>")
    trainer: Trainer = instantiate(cfg.trainer, logger=logger, max_epochs=cfg.num_epochs)
    log.info("Starting training")
    trainer.fit(model=pl_model, datamodule=datamodule)
    log.info("Knowledge transfer completed.")
    log.info("Performing a validation epoch to record performance of the model after knowledge transfer.")
    trainer.validate(model=pl_model, datamodule=datamodule)

    per_class_val_acc = pl_model.per_class_val_acc
    print(f"per_class_val_acc:{per_class_val_acc}")

    print("testing the latest version of the model...")
    trainer.test(model=pl_model, datamodule=datamodule)
    per_class_test_acc = pl_model.per_class_test_acc

    (uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy, query_classes_acc_gain,
     local_classes_accuracy, forgetting) = calculate_client_all_accuracies(
        cfg, per_class_test_acc, data_dists_vectorized, pre_transfer_acc_map)

    latest_results = {
        "best_epoch": trainer.current_epoch,
        "val_per_class_acc": per_class_val_acc,
        "per_class_test_acc": per_class_test_acc,
        "uniform_accuracy": uniform_accuracy,
        "simple_weighted_accuracy": simple_weighted_accuracy,
        "local_classes_accuracy": local_classes_accuracy,
        "query_classes_accuracy": query_classes_accuracy,
        "query_classes_acc_gain": query_classes_acc_gain,
        "forgetting": forgetting,
    }

    print("latest results:")
    for key, value in latest_results.items():
        print(f"{key}: {value}")

    print("testing the manual saved checkpoints...")
    ckpt_dir = trainer.default_root_dir
    # (checkpoint file name, pl_model attribute holding its epoch, description for test_model)
    checkpoint_specs = (
        ("best_val_acc.ckpt", "best_val_acc_epoch", "Best_val_acc"),
        ("val_best_simple_weighted_accuracy.ckpt", "best_simple_weighted_accuracy_epoch",
         "Best_simple_weighted_accuracy"),
        ("val_best_uniform_accuracy.ckpt", "best_uniform_accuracy_epoch", "Best_uniform_accuracy"),
        ("val_best_query_class_acc_gain.ckpt", "best_query_class_acc_gain_epoch", "Best_query_class_acc_gain"),
        ("val_least_forgetting.ckpt", "least_forgetting_epoch", "Least_forgetting"),
    )
    checkpoint_results = [
        test_model(cfg, datamodule, os.path.join(ckpt_dir, ckpt_name), logger, data_dists_vectorized,
                   getattr(pl_model, epoch_attr), pre_transfer_acc_map=pre_transfer_acc_map,
                   description=description)
        for ckpt_name, epoch_attr, description in checkpoint_specs
    ]

    # Order matters: print_summary unpacks this tuple positionally.
    results = (*checkpoint_results, latest_results)

    return pl_model.org_learner_model_copy, pl_model.learner_model, results


def test_model(cfg, datamodule, ckpt_path, logger, data_dists_vectorized, best_epoch, pre_transfer_acc_map,
               description=None):
    """Test a manually saved checkpoint and collect its metrics in a dict.

    Args:
        cfg: Run configuration; ``cfg.test_description`` is overwritten with
            ``description`` and ``cfg.model`` is instantiated to host the weights.
        datamodule: Datamodule passed to ``trainer.test``.
        ckpt_path: Path of the checkpoint handed to ``load_model_qktD``.
        logger: Logger given to the fresh ``Trainer``.
        data_dists_vectorized: Forwarded to ``calculate_client_all_accuracies``.
        best_epoch: Stored verbatim under ``"best_epoch"`` in the result.
        pre_transfer_acc_map: Forwarded to ``calculate_client_all_accuracies``.
        description: Label used in the printed header.

    Returns:
        Dict with the same keys as the "latest" results of a transfer run;
        ``"val_per_class_acc"`` is always ``None`` here.
    """
    cfg.test_description = description
    trainer = Trainer(logger=logger)
    log.info(f"Testing using manual checkpoint: {ckpt_path}")

    # Fresh architecture from cfg.model; load_model_qktD is given a CPU device.
    restored_model = load_model_qktD(model=instantiate(cfg.model), model_path=ckpt_path,
                                     device=torch.device('cpu'))
    evaluator = N_ClassificationModel(cfg=cfg, learner_model=restored_model)
    trainer.test(model=evaluator, datamodule=datamodule)
    per_class_test_acc = evaluator.per_class_test_acc

    metrics = calculate_client_all_accuracies(cfg, per_class_test_acc, data_dists_vectorized,
                                              pre_transfer_acc_map)
    (uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy, query_classes_acc_gain,
     local_classes_accuracy, forgetting) = metrics

    results = dict(
        best_epoch=best_epoch,
        val_per_class_acc=None,
        per_class_test_acc=per_class_test_acc,
        uniform_accuracy=uniform_accuracy,
        simple_weighted_accuracy=simple_weighted_accuracy,
        local_classes_accuracy=local_classes_accuracy,
        query_classes_accuracy=query_classes_accuracy,
        query_classes_acc_gain=query_classes_acc_gain,
        forgetting=forgetting,
    )

    print(f"{description} results:")
    for metric_name, metric_value in results.items():
        print(f"{metric_name}: {metric_value}")

    return results


def perform_centralized_knowledge_transfer(cfg, logger, data_dists_vectorized, starting_model=None,
                                           pre_transfer_acc_map=None):
    """Train a distillation module in the centralized setting and test the final weights.

    Args:
        cfg: Run configuration. In centralized stage 1 it is mutated:
            ``qkt_unweighted_teachers`` is set to ``True`` and, unless
            ``centralized_qkt_use_client_data`` is set, ``with_CE`` is set to ``False``.
        logger: Logger handed to the trainer.
        data_dists_vectorized: Not used by this function.
        starting_model: Model passed as ``l_model`` to the ``_fc`` distillation
            variant; ignored by the other variants.
        pre_transfer_acc_map: Not used by this function.

    Returns:
        ``(pl_model.learner_model, per_class_test_acc)``.

    Raises:
        ValueError: If none of ``cfg.data_kd``, ``cfg.qkt_multi_teachers`` or
            ``cfg.my_kd`` is enabled, so no distillation module can be built.
    """
    datamodule = initialize_data_module(cfg, cfg.learner_client)

    if cfg.centralized_qkt and not cfg.stage2:
        cfg.qkt_unweighted_teachers = True
        if not cfg.centralized_qkt_use_client_data:  # TODO: consider changing this
            cfg.with_CE = False

    if cfg.data_kd:
        pl_model = Distilltion(cfg=cfg)
    elif cfg.qkt_multi_teachers:
        if cfg.with_EWC_fc:
            pl_model = MyDistillationMutlipleTeachers_QKT_fc(cfg, l_model=starting_model)
        else:
            pl_model = MyDistillationMutlipleTeachers_QKT(cfg)
    elif cfg.my_kd:
        pl_model = MyDistillation(cfg=cfg)
    else:
        # Previously this fell through and crashed later with an UnboundLocalError.
        raise ValueError(
            "No knowledge-transfer method selected: enable one of "
            "cfg.data_kd, cfg.qkt_multi_teachers or cfg.my_kd."
        )

    log.info(f"Instantiating Method <{type(pl_model)}>")
    trainer: Trainer = instantiate(cfg.trainer, logger=logger)
    log.info("Starting training")
    trainer.fit(model=pl_model, datamodule=datamodule)
    log.info("Knowledge transfer completed.")
    log.info("Performing a validation epoch to record performance of the model after knowledge transfer.")
    trainer.validate(model=pl_model, datamodule=datamodule)

    per_class_val_acc = pl_model.per_class_val_acc
    print(f"per_class_val_acc:{per_class_val_acc}")

    print("testing the latest version of the model...")
    trainer.test(model=pl_model, datamodule=datamodule)
    per_class_test_acc = pl_model.per_class_test_acc

    return pl_model.learner_model, per_class_test_acc


def print_summary(cfg, data_dists_vectorized, results, description=None, step1=False):
    """Format, print and return the textual summary of a knowledge-transfer run.

    Args:
        cfg: Configuration providing ``learner_client``, ``goal_class``,
            ``teacher_client`` and ``seed`` for the header.
        data_dists_vectorized: Printed verbatim in the header.
        results: Sequence of exactly six metric dicts, in the order: best val acc,
            best simple weighted accuracy, best uniform accuracy, best query-class
            gain, least forgetting, latest weights.
        description: Free-text label shown on the second header line.
        step1: Accepted for call compatibility; not used by the formatting.

    Returns:
        The summary text (also written to stdout).
    """
    (best_val_acc_results, best_simple_weighted_accuracy_results, best_uniform_accuracy_results,
     best_query_class_acc_gain_results, least_forgetting_results, latest_results) = results

    # Section title -> metric dict, in output order.
    sections = (
        ("Best Val_Acc", best_val_acc_results),
        ("Best Simple Weighted Accuracy", best_simple_weighted_accuracy_results),
        ("Best Uniform Accuracy", best_uniform_accuracy_results),
        ("Best Query_classes_acc_gain", best_query_class_acc_gain_results),
        ("Least forgetting", least_forgetting_results),
        ("Latest run", latest_results),
    )
    # Printed label -> key in the metric dict, in output order.
    metric_rows = (
        ("val_per_class_acc", "val_per_class_acc"),
        ("best_epoch", "best_epoch"),
        ("per_class_test_acc", "per_class_test_acc"),
        ("Uniform accuracy", "uniform_accuracy"),
        ("Weighted accuracy (Simple)", "simple_weighted_accuracy"),
        ("local_classes_accuracy", "local_classes_accuracy"),
        ("query_classes_accuracy", "query_classes_accuracy"),
        ("query_classes_acc_gain", "query_classes_acc_gain"),
        ("forgetting", "forgetting"),
    )

    lines = [
        "--- SUMMARY ---",
        f"-- {description} --",
        f"Learner_client-{cfg.learner_client}:",
        f"Query: {list(cfg.goal_class)}",
        f"from-{cfg.teacher_client}",
        f"data dist : {data_dists_vectorized}",
        f"seed: {cfg.seed}",
        "--- after learning ---",
        "-----",
    ]
    for title, metrics in sections:
        lines.append(f"---{title} model---")
        for label, key in metric_rows:
            lines.append(f"{label}: {metrics[key]}")
        lines.append("-----")
    lines.append("--- End of SUMMARY ---")

    # Every line, including the last one, is newline-terminated.
    summary_text = "\n".join(lines) + "\n"

    print(summary_text)
    return summary_text


def replace_classifier(model, new_classifier):
    """Give ``model`` an independent deep copy of ``new_classifier.fc`` as its ``fc``.

    ``model`` is modified in place and returned; ``new_classifier`` is left untouched.
    """
    cloned_head = copy.deepcopy(new_classifier.fc)
    model.fc = cloned_head
    return model


def initialize_data_module(cfg, current_client=0):
    """Configure ``cfg.datamodule`` for a client, instantiate it and set it up for fitting.

    The training-run configuration is fetched with ``get_configs(cfg.train_exp_id)``;
    its ``old_datamodule`` entry and per-client index lists are copied into
    ``cfg.datamodule`` (so ``cfg`` is mutated).

    Args:
        cfg: Run configuration.
        current_client: Client whose train/val indices are used outside of
            centralized stage 1.

    Returns:
        The instantiated datamodule after ``setup('fit')``.
    """
    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>")
    train_cfg = get_configs(cfg.train_exp_id)
    # The wandb run object that used to be fetched here was never read, so the
    # extra remote lookup has been dropped.
    cfg.datamodule.datamodule = train_cfg.old_datamodule
    print(f"cfg.datamodule.datamodule: {cfg.datamodule.datamodule}")
    cfg.datamodule.datamodule.split_function = None
    cfg.datamodule.no_noise = True

    # Decide which client's indices to use and which split feeds training.
    if cfg.centralized_qkt and not cfg.stage2:
        if cfg.centralized_qkt_use_client_data:
            source_client, train_split = cfg.volunteer_client_id, "train_data_indices"
        else:
            # NOTE(review): training here uses client_0's *validation* indices, the
            # same list used for validation below. Kept as in the original; confirm
            # this is intentional.
            source_client, train_split = 0, "val_data_indices"
    else:
        source_client, train_split = current_client, "train_data_indices"

    client_info = train_cfg["clients"][f"client_{source_client}"]
    cfg.datamodule.learner_train_indices = client_info[train_split]
    cfg.datamodule.learner_val_indices = client_info["val_data_indices"]

    datamodule = instantiate(cfg.datamodule)
    datamodule.setup('fit')
    return datamodule


# def calculate_client_all_accuracies(cfg, per_class_test_acc, data_dists_vectorized):
#     learner_client = cfg.learner_client
#
#     client_name = f'client_{learner_client}'
#     class_distribution = data_dists_vectorized[client_name]
#     query_classes = cfg.goal_class
#
#     total_uniform_acc = total_weighted_acc = total_query_class_acc = total_local_class_acc = 0
#     count_uniform_classes = total_weight = total_local_weight = 0
#
#     num_classes = len(per_class_test_acc)
#     query_class_acc = [per_class_test_acc[i] for i in query_classes]
#
#     train_run_id = cfg.train_exp_id
#     api = wandb.Api()
#     train_run = api.run(train_run_id)
#     train_run_summary = train_run.summary._json_dict
#     pre_transfer_acc = train_run_summary[f'client-{learner_client}/per_class_test_acc']
#
#     query_class_acc_gain = [(per_class_test_acc[i] - pre_transfer_acc[i]) for i in query_classes]
#
#     for cls_index in range(num_classes):
#         if class_distribution[cls_index] > 0 or cls_index in query_classes:
#             total_uniform_acc += per_class_test_acc[cls_index]
#             count_uniform_classes += 1
#
#             weight = 1 if cls_index in query_classes else class_distribution[cls_index] / sum(class_distribution)
#             total_weighted_acc += weight * per_class_test_acc[cls_index]
#             total_weight += weight
#
#             if cls_index not in query_classes:
#                 local_weight = class_distribution[cls_index] / sum(class_distribution)
#                 total_local_class_acc += local_weight * per_class_test_acc[cls_index]
#                 total_local_weight += local_weight
#
#     uniform_accuracy = total_uniform_acc / count_uniform_classes if count_uniform_classes > 0 else 0
#     simple_weighted_accuracy = total_weighted_acc / total_weight if total_weight > 0 else 0
#     query_classes_accuracy = sum(query_class_acc) / len(query_classes) if query_classes else 0
#     query_classes_acc_gain = sum(query_class_acc_gain) / len(query_class_acc_gain) if query_classes else 0
#     local_classes_accuracy = total_local_class_acc / total_local_weight if total_local_weight > 0 else 0
#     forgetting = sum((per_class_test_acc[j] - pre_transfer_acc[j]) for j in range(len(per_class_test_acc)) if
#                      (per_class_test_acc[j] - pre_transfer_acc[j]) < 0) / len(
#         [accuracy for accuracy in pre_transfer_acc if accuracy > 0])
#
#     return (uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy, query_classes_acc_gain,
#             local_classes_accuracy, forgetting)


def find_teacher_candidates(cfg, learner_query, teacher_clients, data_dists_vectorized, sample_threshold=0):
    """Return the teachers holding more than ``sample_threshold`` samples of any queried class.

    Args:
        cfg: Unused; kept for signature compatibility.
        learner_query: Class indices the learner asks about.
        teacher_clients: Client ids to screen; input order is preserved in the result.
        data_dists_vectorized: Mapping ``"client_<id>"`` -> per-class sample counts.
        sample_threshold: Strict lower bound on the sample count for a class to count.

    Returns:
        List of qualifying client ids.
    """

    def has_query_data(teacher):
        counts = data_dists_vectorized[f'client_{teacher}']
        return any(counts[cls] > sample_threshold for cls in learner_query)

    return [teacher for teacher in teacher_clients if has_query_data(teacher)]


def my_find_teacher_candidates(cfg, learner_query, teacher_models, threshold=0.05):
    """Pick teacher candidates by probing each teacher model with random-noise inputs.

    Each model is fed the synthetic batch from ``create_anonymized_data_impressions``;
    a teacher becomes a candidate for a queried class when its mean softmax
    probability for that class exceeds ``threshold``.

    Args:
        cfg: Configuration; ``copy_of_self_as_teacher`` and ``learner_client`` are read.
        learner_query: Class indices the learner asks about.
        teacher_models: Models indexed by client id (position in the sequence).
        threshold: Strict lower bound on the mean predicted probability.

    Returns:
        List of candidate client ids (built from a set, so order is not guaranteed).
    """
    class_presence = defaultdict(list)

    for client_id, model in enumerate(teacher_models):
        synthetic_data, _ = create_anonymized_data_impressions(model)
        synthetic_data = synthetic_data.to(next(model.parameters()).device)
        # Probing only needs forward values, so skip building the autograd graph.
        with torch.no_grad():
            predicted_probs = F.softmax(model(synthetic_data), dim=1)

        for cls in range(predicted_probs.size(1)):
            avg_prob = predicted_probs[:, cls].mean().item()
            class_presence[cls].append((avg_prob, client_id))

    candidates = set()
    for cls in learner_query:
        candidates.update(client_id for avg_prob, client_id in class_presence[cls] if avg_prob > threshold)
        print(f"Teachers candidates with meaningful_presence of the query ({learner_query}) are: {candidates}")
        # NOTE(review): this check runs once per queried class, so the learner is
        # never added when the query is empty. Kept as in the original; confirm
        # whether it was meant to sit after the loop.
        if cfg.copy_of_self_as_teacher:
            candidates.add(cfg.learner_client)
            print(f"Also added a copy of the learner client {cfg.learner_client} as a teacher. ")
    return list(candidates)


def create_anonymized_data_impressions(model):
    """Build a batch of Gaussian-noise inputs, one per output class of ``model.fc``.

    Returns:
        ``(data, labels)``: ``data`` stacks ``model.fc.out_features`` noise tensors
        of shape ``(3, 224, 224)``; ``labels`` is the long tensor ``[0, 1, ..., n-1]``.
    """
    num_classes = model.fc.out_features
    input_shape = (3, 224, 224)  # Adjust this shape based on the model's expected input

    # One randn call per class keeps the random stream identical to drawing samples one by one.
    noise_batch = [torch.randn(input_shape) for _ in range(num_classes)]
    class_ids = list(range(num_classes))

    return torch.stack(noise_batch), torch.tensor(class_ids, dtype=torch.long)


def get_data_dists_vectorized(cfg, num_classes=10):
    """Return each client's training class distribution as a dense per-class list.

    The training-run configuration is fetched with ``get_configs(cfg.train_exp_id)``.
    For every entry under ``"clients"`` the ``"train_data_distribution"`` mapping is
    looked up by the string form of each class index; missing classes count as 0.

    Args:
        cfg: Configuration providing ``train_exp_id``.
        num_classes: Length of each output vector.

    Returns:
        Dict mapping client key -> list of ``num_classes`` sample counts.
    """
    train_cfg = get_configs(cfg.train_exp_id)

    # The per-class grand total that used to be accumulated here was never read
    # or returned, so it has been removed.
    return {
        client: [info["train_data_distribution"].get(str(cls_idx), 0) for cls_idx in range(num_classes)]
        for client, info in train_cfg["clients"].items()
    }


def get_clients_Qs(cfg, data_dists_vectorized, num_classes=10, num_clients=10, any_random_class=False,
                   num_classes_to_select=1, sample_threshold=50, seed=42):
    """Draw a random class query for every client.

    The global ``random`` generator is re-seeded with ``seed`` first. For each
    client the candidate pool is either every class (``any_random_class``) or the
    classes with at most ``sample_threshold`` local samples; up to
    ``num_classes_to_select`` classes are sampled from it.

    Args:
        cfg: Unused; kept for signature compatibility.
        data_dists_vectorized: Mapping ``"client_<i>"`` -> per-class sample counts.

    Returns:
        List with one query (list of class indices, possibly empty) per client.
    """
    random.seed(seed)

    queries = []
    for client_index in range(num_clients):
        class_distribution = data_dists_vectorized[f'client_{client_index}']
        print(f"client{client_index}, class_distribution: {class_distribution}")

        possible_classes = (
            list(range(num_classes))
            if any_random_class
            else [cls for cls, count in enumerate(class_distribution) if count <= sample_threshold]
        )
        print(f"possible_classes: {possible_classes}")

        if not possible_classes:
            queries.append([])
            continue

        pick_count = min(len(possible_classes), num_classes_to_select)
        selected_classes = random.sample(possible_classes, pick_count)
        print(f"selected_classes: {selected_classes}")
        queries.append(selected_classes)

    return queries


def get_clients_Qs_variable(cfg, data_dists_vectorized, num_classes=10, num_clients=10, any_random_class=False,
                            sample_threshold=50, seed=42, max_classes_fraction=1):
    """Draw a random, variable-sized class query for every client.

    The global ``random`` generator is re-seeded with ``seed`` first. A positive
    ``cfg.max_classes_fraction`` overrides the ``max_classes_fraction`` argument,
    and with 100 or more classes the fraction is forced to 0.2. For each client
    the query size is drawn uniformly from 1 up to that fraction of the candidate
    pool (at least 1), then that many classes are sampled from the pool.

    Args:
        cfg: Configuration providing ``max_classes_fraction``.
        data_dists_vectorized: Mapping ``"client_<i>"`` -> per-class sample counts.

    Returns:
        List with one query (list of class indices, possibly empty) per client.
    """
    random.seed(seed)

    if cfg.max_classes_fraction > 0:
        max_classes_fraction = cfg.max_classes_fraction
    if num_classes >= 100:
        max_classes_fraction = 0.2

    queries = []
    for client_index in range(num_clients):
        class_distribution = data_dists_vectorized[f'client_{client_index}']
        print(f"Client {client_index}, class distribution: {class_distribution}")

        possible_classes = (
            list(range(num_classes))
            if any_random_class
            else [cls for cls, count in enumerate(class_distribution) if count <= sample_threshold]
        )
        print(f"Possible classes for client {client_index}: {possible_classes}")

        if not possible_classes:
            queries.append([])
            continue

        upper_bound = max(int(max_classes_fraction * len(possible_classes)), 1)
        query_size = random.randint(1, upper_bound)
        selected_classes = random.sample(possible_classes, query_size)
        print(f"Selected classes for client {client_index}: {selected_classes}")
        queries.append(selected_classes)

    return queries




def get_clients_Qs_personalized(cfg, data_dists_vectorized, num_classes=10, num_clients=10):
    """Build each client's query from the classes it already has samples for.

    Args:
        cfg: Unused; kept for signature compatibility.
        data_dists_vectorized: Mapping ``"client_<i>"`` -> per-class sample counts.
        num_classes: Unused; kept for signature compatibility.
        num_clients: Number of clients, indexed ``0 .. num_clients - 1``.

    Returns:
        List with one query per client: indices of classes with a positive count.
    """
    queries = []
    for client_index in range(num_clients):
        class_distribution = data_dists_vectorized[f'client_{client_index}']
        print(f"Client {client_index}, class distribution: {class_distribution}")

        client_classes = [cls for cls, count in enumerate(class_distribution) if count > 0]
        print(f"classes of client {client_index}: {client_classes}")

        queries.append(client_classes)

    return queries


def get_binary_teachers_weights(client_models, data_dists_vectorized, num_classes=10):
    """Compute one binary per-class weight vector for every client model.

    Each model is analysed with ``detect_stats(model, threshold=0.01)`` and its
    result is appended to the output, so entry ``i`` belongs to client ``i``.
    (The append used to sit after the loop, which kept only the last client's
    vector and raised ``NameError`` for an empty model list.)

    Args:
        client_models: Models ordered by client index.
        data_dists_vectorized: Mapping ``"client_<i>"`` -> per-class sample counts;
            only printed for reference.
        num_classes: Unused; kept for signature compatibility.

    Returns:
        List of binary weight vectors, one per model (empty for no models).
    """
    binary_teachers_weights = []
    for i, model in enumerate(client_models):
        print(f"\nAnalyzing client {i}'s model")
        client_dist = data_dists_vectorized[f"client_{i}"]
        print(f"client_dist: {client_dist}")

        print("Detected stats:")
        binary_teachers_weights.append(detect_stats(model, threshold=0.01))
    return binary_teachers_weights


def detect_stats(model, threshold=0.05):
    """Probe ``model`` with random-noise inputs and flag classes it rarely predicts.

    The synthetic batch from ``create_anonymized_data_impressions`` is pushed
    through the model; a class whose mean softmax probability over that batch is
    below ``threshold`` gets weight 0, every other class gets weight 1.

    Args:
        model: Model to probe; its first parameter decides the device of the batch.
        threshold: Mean probability below which a class counts as underrepresented.

    Returns:
        List of 0/1 weights, one per output class.
    """
    synthetic_data, _ = create_anonymized_data_impressions(model)
    synthetic_data = synthetic_data.to(next(model.parameters()).device)  # Move data to the same device as the model
    # Probing only needs forward values, so skip building the autograd graph.
    with torch.no_grad():
        predicted_probs = F.softmax(model(synthetic_data), dim=1)

    underrepresented_classes = []
    avg_predicted_probs = []
    binary_teacher_weights = []
    for cls in range(predicted_probs.size(1)):
        avg_prob = predicted_probs[:, cls].mean().item()
        avg_predicted_probs.append(avg_prob)
        if avg_prob < threshold:
            underrepresented_classes.append(cls)
            binary_teacher_weights.append(0)
        else:
            binary_teacher_weights.append(1)

    print(f"model's avg_predicted_probs: {avg_predicted_probs}")
    print(f"model's underrepresented_classes: {underrepresented_classes}")
    print(f"model's binary_teacher_weights: {binary_teacher_weights}")

    return binary_teacher_weights


def save_model(model, folder_path, model_name):
    """Write ``model.state_dict()`` to ``folder_path/model_name`` and return that path.

    ``folder_path`` (including missing parents) is created when absent.
    """
    os.makedirs(folder_path, exist_ok=True)
    destination = os.path.join(folder_path, model_name)
    weights = model.state_dict()
    torch.save(weights, destination)
    return destination



def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) and configure cuDNN flags.

    Also exports ``PYTHONHASHSEED`` and sets ``cudnn.deterministic = True`` /
    ``cudnn.benchmark = False``.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    os.environ['PYTHONHASHSEED'] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Script entry point: run the Hydra-decorated application when executed directly.
if __name__ == "__main__":
    my_app()
